import random
from tqdm import tqdm
import torch
import json
import os


def mean_pooling(token_embeddings, mask):
    """Average token embeddings along ``dim=1`` while ignoring masked positions.

    Positions whose mask value is zero are overwritten with 0 before the sum,
    and the sum is divided by the per-row number of unmasked positions
    (``mask.sum(dim=1)``).

    Args:
        token_embeddings: Tensor with the sequence on axis 1 -- presumably
            (batch, sequence, hidden); TODO confirm against the encoder used.
        mask: Tensor aligned with the leading axes of ``token_embeddings``;
            non-zero entries mark positions to keep.

    Returns:
        The masked mean, with axis 1 reduced away. A row whose mask is all
        zeros is divided by zero (no guard is applied here).
    """
    padding = ~mask[..., None].bool()
    summed = token_embeddings.masked_fill(padding, 0.).sum(dim=1)
    kept_counts = mask.sum(dim=1)[..., None]
    return summed / kept_counts


def get_sent_embeddings(sents, contriever, tok, BSZ=32, device="cuda", show_progress=False):
    """Embed a list of sentences in batches and return one stacked CPU tensor.

    Each batch is tokenized with padding and truncation, moved to ``device``,
    run through ``contriever`` without gradient tracking, and reduced with
    ``mean_pooling`` over ``outputs[0]`` using the batch's attention mask.
    Per-batch results are moved to the CPU and stacked with ``torch.vstack``.

    Args:
        sents: Sliceable sequence of sentences (``sents[i:i + BSZ]`` is passed
            to the tokenizer).
        contriever: Encoder called as ``contriever(**inputs)``; its first
            output is pooled. Presumably already placed on ``device`` --
            verify against the caller.
        tok: Tokenizer called with ``padding=True, truncation=True,
            return_tensors='pt'``; its result must support ``.to(device)``.
        BSZ: Batch size.
        device: Device the tokenized inputs are moved to. Defaults to
            ``"cuda"``, the value that used to be hard-coded.
        show_progress: When true, show the tqdm progress bar (it was always
            disabled before, which remains the default).

    Returns:
        Tensor with one row per sentence, on the CPU.

    Note:
        ``torch.vstack`` is called on the collected batches, so an empty
        ``sents`` is not supported.
    """
    all_embs = []
    for start in tqdm(range(0, len(sents), BSZ), disable=not show_progress):
        sent_batch = sents[start:start + BSZ]
        inputs = tok(sent_batch, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = contriever(**inputs)
            embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
        # Keep the accumulated embeddings off the accelerator.
        all_embs.append(embeddings.cpu())
    return torch.vstack(all_embs)


def retrieve_facts(query, fact_embs, contriever, tok, k=1, device="cuda"):
    """Return the indices of the ``k`` facts most similar to ``query``.

    The query is tokenized, moved to ``device``, encoded with ``contriever``
    and mean-pooled; the pooled vector is moved to the CPU and scored against
    every row of ``fact_embs`` with a dot product.

    Args:
        query: Query string.
        fact_embs: 2-D tensor with one row per fact. It is multiplied with a
            CPU tensor, so it is expected to live on the CPU as well.
        contriever: Encoder called as ``contriever(**inputs)``.
        tok: Tokenizer whose result supports ``.to(device)`` and item access
            to ``'attention_mask'``.
        k: Number of neighbours requested. It is clamped to the number of
            facts, because ``topk`` raises when ``k`` exceeds the size of the
            scored dimension.
        device: Device the tokenized query is moved to. Defaults to
            ``"cuda"``, the value that used to be hard-coded.

    Returns:
        Tensor of fact indices ordered from most to least similar.
    """
    inputs = tok([query], padding=True, truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = contriever(**inputs)
        query_emb = mean_pooling(outputs[0], inputs['attention_mask']).cpu()

    # Single query -> take row 0 of the (1, num_facts) score matrix.
    sim = (query_emb @ fact_embs.T)[0]
    k = min(k, sim.shape[0])
    knn = sim.topk(k, largest=True)

    return knn.indices


def call_model(prompt, stop, model, gptj_tokenizer, device='cuda', generate_length=50, temperature=1.0):
    """Sample a continuation for ``prompt`` and return the decoded text.

    Args:
        prompt: Text handed to the tokenizer.
        stop: Forwarded to ``model.generate`` as ``stopping_criteria``.
        model: Object exposing ``generate``.
        gptj_tokenizer: Tokenizer providing ``__call__``, ``batch_decode``,
            ``eos_token`` and ``eos_token_id``.
        device: Device the input ids and attention mask are moved to.
        generate_length: Added to the prompt's token count to form
            ``max_length``.
        temperature: Sampling temperature.

    Returns:
        The first decoded sequence with every occurrence of the tokenizer's
        EOS token string removed. The text is whatever ``batch_decode``
        yields for the generated sequence -- presumably prompt plus
        continuation; verify against the model used.
    """
    tokenized = gptj_tokenizer(prompt, return_tensors="pt")
    prompt_ids = tokenized['input_ids'].to(device)
    prompt_mask = tokenized['attention_mask'].to(device)

    generation_kwargs = dict(
        do_sample=True,
        attention_mask=prompt_mask,
        max_length=len(prompt_ids[0]) + generate_length,
        stopping_criteria=stop,
        temperature=temperature,
        # EOS doubles as the padding id during generation.
        pad_token_id=gptj_tokenizer.eos_token_id,
    )
    sequences = model.generate(prompt_ids, **generation_kwargs)

    first_text = gptj_tokenizer.batch_decode(sequences)[0]
    return first_text.replace(gptj_tokenizer.eos_token, '')
    
    
def call_model_template(prompt, stop, model, gptj_tokenizer, device, generate_length=50, temperature=1.0, front_space=4):
    """Sample a continuation for ``prompt`` wrapped in a USER/ASSISTANT chat template.

    The prompt is substituted into a fixed chat preamble, tokenized, and
    sampled from. After decoding, the EOS token string is removed and the
    first ``front_space + len(wrapped prompt)`` characters are skipped; the
    remainder is appended to the *unwrapped* ``prompt``.

    Args:
        prompt: User text inserted into the template.
        stop: Forwarded to ``model.generate`` as ``stopping_criteria``.
        model: Object exposing ``generate``.
        gptj_tokenizer: Tokenizer providing ``__call__``, ``batch_decode``,
            ``eos_token`` and ``eos_token_id``.
        device: Device the input ids and attention mask are moved to.
        generate_length: Added to the prompt's token count to form
            ``max_length``.
        temperature: Sampling temperature.
        front_space: Extra leading characters of the decoded text to skip --
            presumably the decoded BOS marker and a space; TODO confirm for
            the tokenizer in use.

    Returns:
        ``prompt`` followed by the decoded text past the skipped prefix.
    """
    chat_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        "USER: {} ASSISTANT: \n"
    )
    wrapped_prompt = chat_template.format(prompt)

    tokenized = gptj_tokenizer(wrapped_prompt, return_tensors="pt")
    prompt_ids = tokenized['input_ids'].to(device)
    prompt_mask = tokenized['attention_mask'].to(device)

    sequences = model.generate(
        prompt_ids,
        do_sample=True,
        attention_mask=prompt_mask,
        max_length=len(prompt_ids[0]) + generate_length,
        stopping_criteria=stop,
        temperature=temperature,
        pad_token_id=gptj_tokenizer.eos_token_id,
    )

    decoded = gptj_tokenizer.batch_decode(sequences)[0]
    decoded = decoded.replace(gptj_tokenizer.eos_token, '')

    # Skip the echoed template (plus `front_space` leading characters).
    continuation = decoded[front_space + len(wrapped_prompt):]
    return prompt + continuation
    


def remove_extra_target_occurrences(gen, target, count):
    """Truncate ``gen`` just before its ``(count + 1)``-th occurrence of ``target``.

    If ``target`` occurs at most ``count`` times, ``gen`` is returned as is.
    Otherwise the text is cut two characters before the start of the first
    surplus occurrence (the two characters are presumably the blank-line
    separator preceding it -- confirm against the prompt format).

    Args:
        gen: Text to truncate.
        target: Marker substring; occurrences are counted without overlap.
        count: Number of occurrences to keep; expected to be >= 0.

    Returns:
        The possibly truncated text. The cut position is clamped at 0, so a
        surplus occurrence starting within the first two characters yields
        ``""`` instead of a slice that wraps around from the end of the
        string.
    """
    if gen.count(target) <= count:
        return gen

    # Walk to the start of the (count + 1)-th non-overlapping occurrence.
    # More than `count` occurrences exist, so `find` cannot fail here.
    start = -len(target)
    for _ in range(count + 1):
        start = gen.find(target, start + len(target))

    # A negative stop index would count from the end of the string.
    cut = max(start - 2, 0)
    return gen[:cut]


def able_to_quit(gen, task_prompt):
    """Decide whether ``gen`` already contains a final answer and extract it.

    Args:
        gen: Full text produced so far (it is searched from an offset derived
            from ``task_prompt``, so it presumably starts with that prompt).
        task_prompt: Few-shot prompt; only its length is used.

    Returns:
        ``(True, answer)`` once ``'Final answer: '`` occurs at least five
        times in ``gen``, else ``(False, None)``. ``answer`` is the rest of
        the line following the first marker found at or after
        ``len(task_prompt) + 13``, with one trailing ``'.'`` removed. If no
        marker is found from that offset, the search result (-1) is still
        used as the slice start, which typically yields an empty answer.
    """
    marker = 'Final answer: '

    # The prompt contains 4 shots + 1 evaluation example -> 5.
    if gen.count(marker) < 5:
        return False, None

    # NOTE(review): the meaning of the extra 13-character offset is not
    # visible in this file -- confirm with the prompt layout.
    start = gen.find(marker, len(task_prompt) + 13)
    first_line = gen[start:].strip().split('\n')[0]
    answer = first_line[len(marker):]
    if answer.endswith('.'):
        answer = answer[:-1]
    return True, answer


def break_down_into_subquestions(d, breakdown_prompt, sc_done, gptj_tokenizer, model, front_space=0):
    """Ask the model to turn each of a case's three questions into a relation chain.

    For each of ``d['questions'][0..2]`` a prompt is built from
    ``breakdown_prompt`` plus the question and an opening ``"<subject>->``
    stub, and sampled with ``call_model`` at temperature 0.2. From each
    generation the fifth blank-line-separated section is taken (presumably
    the section after four few-shot examples -- confirm against
    ``breakdown_prompt``), and the last line of that section is parsed with
    ``extract_entities``.

    Args:
        d: Case record; reads ``d['orig']['triples_labeled'][0][0]`` as the
            subject and ``d['questions']``.
        breakdown_prompt: Few-shot prefix prepended to every prompt.
        sc_done: Stopping criteria forwarded to ``call_model``.
        gptj_tokenizer: Tokenizer forwarded to ``call_model``.
        model: Model forwarded to ``call_model``.
        front_space: Accepted but not used by this function.

    Returns:
        ``(subject, chains)`` where ``chains`` holds one extracted list per
        question, in question order.
    """
    subject = d['orig']['triples_labeled'][0][0]

    # Build every prompt before the first model call.
    prompts = [
        breakdown_prompt + f"Given this problem:\n{d['questions'][q]}\nExtract relations in square parentheses into follows:\n\"{subject}->"
        for q in range(3)
    ]

    # Prompts are generated one at a time rather than as a batch.
    generations = [
        call_model(p, sc_done, temperature=0.2, model=model, gptj_tokenizer=gptj_tokenizer)
        for p in prompts
    ]

    relation_chains = []
    for text in generations:
        section = text.split("\n\n")[4]
        last_line = section.strip().split("\n")[-1]
        relation_chains.append(extract_entities(last_line))

    return subject, relation_chains


def call_model_batch(prompts, stop, gptj_tokenizer, model, generate_length=50, temperature=1.0, device='cuda'):
    """Sample continuations for a batch of prompts and return the decoded texts.

    The prompts are tokenized together with padding and truncation. The
    attention mask produced by the tokenizer is passed to ``generate`` so
    padded positions are not attended to, and an explicit ``pad_token_id`` is
    supplied (the tokenizer's pad id, falling back to its EOS id), matching
    what ``call_model`` does for single prompts.

    Args:
        prompts: List of prompt strings.
        stop: Forwarded to ``model.generate`` as ``stopping_criteria``.
        gptj_tokenizer: Tokenizer providing ``__call__``, ``batch_decode``
            and ``eos_token_id`` (``pad_token_id`` is used when set).
        model: Object exposing ``generate``.
        generate_length: Added to the padded prompt length to form
            ``max_length``.
        temperature: Sampling temperature.
        device: Device the inputs are moved to. Defaults to ``'cuda'``, the
            device that used to be hard-coded.

    Returns:
        One decoded string per prompt, with special tokens skipped.

    Note:
        NOTE(review): decoder-only models generally need left padding for
        batched generation -- confirm the tokenizer's ``padding_side``.
    """
    # Tokenize the whole batch at once; padding makes the rows rectangular.
    encoding = gptj_tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    pad_token_id = getattr(gptj_tokenizer, 'pad_token_id', None)
    if pad_token_id is None:
        pad_token_id = gptj_tokenizer.eos_token_id

    gen_tokens = model.generate(
        input_ids,
        do_sample=True,
        attention_mask=attention_mask,
        max_length=input_ids.shape[1] + generate_length,
        stopping_criteria=stop,
        temperature=temperature,
        pad_token_id=pad_token_id,
    )

    return gptj_tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
    

def extract_entities(input_string):
    """Pull the parenthesised items out of an ``a->(b)->(c)`` style chain.

    The string is split on ``'->'``. Every piece that contains both ``'('``
    and ``')'`` is stripped of surrounding whitespace and then has its first
    and last character dropped (no check is made that those characters are
    actually the parentheses).

    Args:
        input_string: Chain text to parse.

    Returns:
        List of extracted items in order of appearance; empty when no piece
        contains both parentheses.
    """
    return [
        piece.strip()[1:-1]
        for piece in input_string.split('->')
        if '(' in piece and ')' in piece
    ]


def get_ent_rel_id(file_path, dataset_name):
    """Load the entity/relation id lookup tables stored for a dataset.

    Variant dataset names are first mapped to the directory that holds their
    tables (``CF-3k``/``CF-3k-old`` -> ``CF``; ``CF-3151``/``CF-6334`` ->
    ``CF-9k``; ``T-old`` -> ``T``). Four JSON files are then read from
    ``{file_path}/datasets/{dataset_name}/``.

    Args:
        file_path: Root directory containing the ``datasets`` folder.
        dataset_name: Dataset identifier, possibly one of the variants above.

    Returns:
        ``(entity2id, id2entity, rel2id, id2rel)`` as parsed from the JSON
        files of the same names.
    """
    # Per the original author's note: these tables cover the whole dataset,
    # but downstream they are only consulted for the keys of kg_s_r_o, i.e.
    # entities appearing in the edits of the selected cases, so no
    # information about unselected cases is leaked.
    alias_groups = (
        (("CF-3k", "CF-3k-old"), "CF"),
        (("CF-3151", "CF-6334"), "CF-9k"),
        (("T-old",), "T"),
    )
    for variants, canonical in alias_groups:
        if dataset_name in variants:
            dataset_name = canonical

    tables = []
    for stem in ("entity2id", "id2entity", "rel2id", "id2rel"):
        with open(f'{file_path}/datasets/{dataset_name}/{stem}.json', 'r') as f:
            tables.append(json.load(f))

    entity2id, id2entity, rel2id, id2rel = tables
    return entity2id, id2entity, rel2id, id2rel


def get_ent_alias(dataset, entity2id):
    """Collect alias sets for hop answers and map each alias to an entity id.

    For every record, its ``'single_hops'`` are visited in order. Processing
    of a record stops at the first hop whose ``'answer'`` is not a key of
    ``entity2id``; later hops of that record are skipped.

    Args:
        dataset: Iterable of records, each with a ``'single_hops'`` list whose
            items carry ``'answer'`` and ``'answer_alias'``.
        entity2id: Mapping from entity name to id.

    Returns:
        ``(ent2alias, alias2id)``: answer -> set of its aliases (a later hop
        with the same answer replaces the earlier set), and alias -> id of
        the answer it belongs to (later assignments win).
    """
    alias_sets = {}
    alias_ids = {}
    for record in dataset:
        for hop in record['single_hops']:
            entity = hop['answer']
            if entity not in entity2id:
                # Unknown answer: abandon the remaining hops of this record.
                break
            names = hop['answer_alias']
            alias_sets[entity] = set(names)
            alias_ids.update(dict.fromkeys(names, entity2id[entity]))

    return alias_sets, alias_ids


def process_kg(dataset, rand_list, id2entity, id2rel):
    """Build knowledge-graph views of the edit triples of the selected cases.

    Only records whose ``'case_id'`` is in ``rand_list`` are used. Their
    ``d['orig']['edit_triples']`` are paired with ``d['requested_rewrite']``
    via ``zip`` (so the shorter of the two bounds the iteration); the rewrite
    entries themselves are not read.

    Args:
        dataset: Iterable of case records.
        rand_list: Container of selected case ids.
        id2entity: Lookup from entity id to entity name.
        id2rel: Lookup from relation id to relation name.

    Returns:
        A 4-tuple ``(edit_kg, kg_s_r_o, rels, ents)``:

        * ``edit_kg``: ``{s: {o: {r, ...}}}``.
        * ``kg_s_r_o``: ``{s: {r: [o, {case ids}]}}``. The first object seen
          for an ``(s, r)`` pair is kept; a later triple with a different
          object is reported on stdout and not stored here.
        * ``rels`` / ``ents``: lists of the distinct relation / entity names
          encountered (built from sets, so their order is unspecified).
    """
    edit_kg = {}
    kg_s_r_o = {}
    rel_names = set()
    ent_names = set()

    for d in dataset:
        caseid = d['case_id']
        if caseid not in rand_list:
            continue

        triples = d['orig']['edit_triples']
        for (s, r, o), _rewrite in zip(triples, d["requested_rewrite"]):
            rel_names.add(id2rel[r])
            ent_names.add(id2entity[s])
            ent_names.add(id2entity[o])

            # Subject -> object -> relations view.
            edit_kg.setdefault(s, {}).setdefault(o, set()).add(r)

            # Subject -> relation -> [object, case ids] view, used to spot
            # two edits that assign different objects to the same (s, r).
            by_relation = kg_s_r_o.setdefault(s, {})
            if r not in by_relation:
                by_relation[r] = [o, {caseid}]
                continue

            entry = by_relation[r]
            if o != entry[0]:
                print("==" * 50)
                print(f"New fact: object = {o}, caseid = {caseid}")
                print(f"{entry}")
                print("==" * 50)
            else:
                entry[1].add(caseid)

    return edit_kg, kg_s_r_o, list(rel_names), list(ent_names)


